{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pose_estimator_2d import openpose_estimator\n",
    "from pose_estimator_3d import estimator_3d\n",
    "from utils import smooth, vis, camera\n",
    "from bvh_skeleton import openpose_skeleton, h36m_skeleton, cmu_skeleton\n",
    "\n",
    "import cv2\n",
    "import importlib\n",
    "import numpy as np\n",
    "import os\n",
    "from pathlib import Path\n",
    "from IPython.display import HTML"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize 2D pose estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OpenPose is the only 2D backend for now; more 2D pose estimators\n",
    "# (HRNet, PoseResNet, CPN, etc.) will be added later.\n",
    "# Set model_folder to /path/to/openpose/models on your machine.\n",
    "e2d = openpose_estimator.OpenPoseEstimator(\n",
    "    model_folder='/openpose/models/'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimate 2D pose from video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_file = Path('miscs/cxk.mp4')  # video file to process\n",
    "output_dir = Path(f'miscs/{video_file.stem}_cache')\n",
    "# exist_ok avoids the check-then-create race; parents mirrors os.makedirs\n",
    "output_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "cap = cv2.VideoCapture(str(video_file))\n",
    "if not cap.isOpened():\n",
    "    # fail early instead of producing an empty keypoints_list that only\n",
    "    # breaks later when the poses are stacked\n",
    "    cap.release()\n",
    "    raise OSError(f'Cannot open video file: {video_file}')\n",
    "\n",
    "keypoints_list = []\n",
    "img_width, img_height = None, None\n",
    "try:\n",
    "    while True:\n",
    "        ret, frame = cap.read()\n",
    "        if not ret:\n",
    "            # no more frames (or a read error): stop\n",
    "            break\n",
    "        # frame size is reused later by the 3D estimator\n",
    "        img_height, img_width = frame.shape[:2]\n",
    "\n",
    "        # returned shape is expected to be (num_of_human, 25, 3)\n",
    "        # last dimension includes (x, y, confidence)\n",
    "        keypoints = e2d.estimate(img_list=[frame])[0]\n",
    "        if not isinstance(keypoints, np.ndarray) or keypoints.ndim != 3:\n",
    "            # failed to detect human; keep a placeholder for this frame\n",
    "            keypoints_list.append(None)\n",
    "        else:\n",
    "            # we assume that the image only contains 1 person\n",
    "            # multi-person video needs some extra processing like grouping;\n",
    "            # maybe it will be implemented in the future\n",
    "            keypoints_list.append(keypoints[0])\n",
    "finally:\n",
    "    # always free the capture, even if the estimator raises mid-video\n",
    "    cap.release()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process 2D pose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter out frames whose detection failed.\n",
    "# Only method='ignore' is used here; an interpolation method will be implemented later.\n",
    "keypoints_list = smooth.filter_missing_value(keypoints_list=keypoints_list, method='ignore')\n",
    "\n",
    "# TODO: smoothing is not implemented yet\n",
    "\n",
    "# Keep the first two columns of the last axis (x, y) and save the 2D pose result\n",
    "pose2d = np.stack(keypoints_list, axis=0)[:, :, :2]\n",
    "pose2d_file = output_dir / '2d_pose.npy'\n",
    "np.save(pose2d_file, pose2d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize 2D pose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_result_dir = output_dir / '2d_pose_vis'  # path to save the visualized images\n",
    "vis_result_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "op_skel = openpose_skeleton.OpenPoseSkeleton()\n",
    "\n",
    "# NOTE(review): keypoints_list is reassigned by smooth.filter_missing_value\n",
    "# before this cell runs; if that call drops failed frames, keypoints and\n",
    "# video frames are presumably no longer aligned here - confirm.\n",
    "cap = cv2.VideoCapture(str(video_file))\n",
    "try:\n",
    "    for i, keypoints in enumerate(keypoints_list):\n",
    "        ret, frame = cap.read()\n",
    "        if not ret:\n",
    "            # video ended before the keypoints did\n",
    "            break\n",
    "\n",
    "        # keypoint whose detect confidence under kp_thresh will not be visualized\n",
    "        vis.vis_2d_keypoints(\n",
    "            keypoints=keypoints,\n",
    "            img=frame,\n",
    "            skeleton=op_skel,\n",
    "            kp_thresh=0.4,\n",
    "            output_file=vis_result_dir / f'{i:04d}.png'\n",
    "        )\n",
    "finally:\n",
    "    # release the capture even if drawing or saving a frame fails\n",
    "    cap.release()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize 3D pose estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the module first so that edits to estimator_3d are picked up\n",
    "# without restarting the kernel.\n",
    "importlib.reload(estimator_3d)\n",
    "\n",
    "# Config and checkpoint files of the model used for 2D -> 3D estimation\n",
    "e3d_files = dict(\n",
    "    config_file='models/openpose_video_pose_243f/video_pose.yaml',\n",
    "    checkpoint_file='models/openpose_video_pose_243f/best_58.58.pth'\n",
    ")\n",
    "e3d = estimator_3d.Estimator3D(**e3d_files)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimate 3D pose from 2D pose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the saved 2D pose back from disk and pass it to the 3D estimator,\n",
    "# together with the frame size recorded while reading the video.\n",
    "pose2d = np.load(pose2d_file)\n",
    "pose3d = e3d.estimate(\n",
    "    pose2d,\n",
    "    image_width=img_width,\n",
    "    image_height=img_height\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert 3D pose from camera coordinates to world coordinates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the parameters of one camera (by subject and camera id) in cameras.h5\n",
    "subject, cam_id = 'S1', '55011271'\n",
    "cam_params = camera.load_camera_params('cameras.h5')[subject][cam_id]\n",
    "R, azimuth = cam_params['R'], cam_params['azimuth']\n",
    "T = 0  # a zero T is passed to camera2world\n",
    "\n",
    "pose3d_world = camera.camera2world(pose=pose3d, R=R, T=T)\n",
    "\n",
    "# rebase the height: shift index 2 of the last axis so its minimum is 0\n",
    "lowest = np.min(pose3d_world[:, :, 2])\n",
    "pose3d_world[:, :, 2] -= lowest\n",
    "\n",
    "# save the world-coordinate 3D pose next to the other cached results\n",
    "pose3d_file = output_dir / '3d_pose.npy'\n",
    "np.save(pose3d_file, pose3d_world)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize 3D pose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Animate the first 300 frames of the world-coordinate 3D pose.\n",
    "# The output format can be .gif or .mp4, chosen by the file extension.\n",
    "gif_file = output_dir / '3d_pose_300.gif'\n",
    "h36m_skel = h36m_skeleton.H36mSkeleton()\n",
    "\n",
    "ani = vis.vis_3d_keypoints_sequence(\n",
    "    keypoints_sequence=pose3d_world[:300],\n",
    "    skeleton=h36m_skel,\n",
    "    azimuth=azimuth,\n",
    "    fps=60,\n",
    "    output_file=gif_file,\n",
    ")\n",
    "\n",
    "# show the animation inline in the notebook\n",
    "HTML(ani.to_jshtml())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert 3D pose to BVH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the 3D pose as BVH using the CMU skeleton; the file is named\n",
    "# after the input video and written into the cache directory.\n",
    "cmu_skel = cmu_skeleton.CMUSkeleton()\n",
    "bvh_file = output_dir / f'{video_file.stem}.bvh'\n",
    "channels, header = cmu_skel.poses2bvh(\n",
    "    pose3d_world,\n",
    "    output_file=bvh_file\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the same 3D pose as BVH using the H36M skeleton.\n",
    "# The file name is derived from the input video instead of being hard-coded,\n",
    "# so a different video_file no longer overwrites miscs/h36m_cxk.bvh\n",
    "# (for the default video the resulting path is unchanged).\n",
    "output = f'miscs/h36m_{video_file.stem}.bvh'\n",
    "h36m_skel = h36m_skeleton.H36mSkeleton()\n",
    "# the return value of poses2bvh is not needed here\n",
    "_ = h36m_skel.poses2bvh(pose3d_world, output_file=output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
